{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example selectors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "source": [
    "The base interface is defined as follows:<br>\n",
    "基本接口定义如下：\n",
    "```python\n",
    "class BaseExampleSelector(ABC):\n",
    "    \"\"\"Interface for selecting examples to include in prompts.\"\"\"\n",
    "\n",
    "    @abstractmethod\n",
    "    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:\n",
    "        \"\"\"Select which examples to use based on the inputs.\"\"\"\n",
    "        \n",
    "    @abstractmethod\n",
    "    def add_example(self, example: Dict[str, str]) -> Any:\n",
    "        \"\"\"Add new example to store.\"\"\"\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example Selector Types\n",
    "\n",
    "| Name       | Description |\n",
    "|------------|-------------|\n",
    "| Similarity | 使用输入和示例之间的语义相似性来决定选择哪些示例 |\n",
    "| MMR        | 使用输入和示例之间的最大边际相关性来决定选择哪些示例 |\n",
    "| Length     | 根据在一定长度内可以容纳的数量来选择示例 |\n",
    "| Ngram      | 使用输入和示例之间的 ngram 重叠来决定选择哪些示例 |\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Custom Example Selector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.example_selectors.base import BaseExampleSelector\n",
    "\n",
    "\n",
    "class CustomExampleSelector(BaseExampleSelector):\n",
    "    \"\"\"Select the stored example whose 'input' is closest in length to the query.\n",
    "\n",
    "    Examples are dicts that carry an 'input' key. The list handed to the\n",
    "    constructor is kept by reference, so ``add_example`` also appends to the\n",
    "    caller's list.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, examples):\n",
    "        self.examples = examples\n",
    "\n",
    "    def add_example(self, example):\n",
    "        \"\"\"Append one example dict to the store.\"\"\"\n",
    "        self.examples.append(example)\n",
    "\n",
    "    def select_examples(self, input_variables):\n",
    "        \"\"\"Return a list holding the closest-length example, or [] if none are stored.\n",
    "\n",
    "        Args:\n",
    "            input_variables: mapping that must contain an 'input' key; only the\n",
    "                length of that value is used.\n",
    "\n",
    "        Returns:\n",
    "            A one-element list with the stored example whose 'input' length\n",
    "            differs least from the query's. Ties go to the example stored\n",
    "            first. An empty list is returned when the store is empty.\n",
    "\n",
    "        Raises:\n",
    "            KeyError: if 'input' is missing from ``input_variables`` or from a\n",
    "                stored example.\n",
    "        \"\"\"\n",
    "        # The selector relies on the caller supplying an 'input' key.\n",
    "        target_length = len(input_variables[\"input\"])\n",
    "\n",
    "        # With nothing stored there is no match; return no examples instead of\n",
    "        # [None], since None is not a usable example dict.\n",
    "        if not self.examples:\n",
    "            return []\n",
    "\n",
    "        # min() keeps the first of several equally close examples, which is the\n",
    "        # same tie-break as a left-to-right scan with a strict '<' comparison.\n",
    "        best_match = min(\n",
    "            self.examples,\n",
    "            key=lambda example: abs(len(example[\"input\"]) - target_length),\n",
    "        )\n",
    "        return [best_match]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'input': 'bye', 'output': 'arrivederci'}]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# English -> Italian word pairs used as the few-shot example pool.\n",
    "examples = [\n",
    "    {\"input\": \"hi\", \"output\": \"ciao\"},\n",
    "    {\"input\": \"bye\", \"output\": \"arrivederci\"},\n",
    "    {\"input\": \"soccer\", \"output\": \"calcio\"},\n",
    "]\n",
    "example_selector = CustomExampleSelector(examples)\n",
    "\n",
    "# \"okay\" has 4 characters, so the 3-character \"bye\" is the closest match.\n",
    "example_selector.select_examples({\"input\": \"okay\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'input': 'hand', 'output': 'mano'}]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Store a 4-character example; it now matches the length of \"okay\" exactly.\n",
    "new_example = {\"input\": \"hand\", \"output\": \"mano\"}\n",
    "example_selector.add_example(new_example)\n",
    "example_selector.select_examples({\"input\": \"okay\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use in a Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Translate the following words from English to Italian:\n",
      "\n",
      "Input: hand -> Output: mano\n",
      "\n",
      "Input: word -> Output:\n"
     ]
    }
   ],
   "source": [
    "from langchain_core.prompts.few_shot import FewShotPromptTemplate\n",
    "from langchain_core.prompts.prompt import PromptTemplate\n",
    "\n",
    "# Template used to render each selected example.\n",
    "example_prompt = PromptTemplate.from_template(\"Input: {input} -> Output: {output}\")\n",
    "\n",
    "# The selector picks the examples at format time, based on the 'input' value.\n",
    "prompt = FewShotPromptTemplate(\n",
    "    example_selector=example_selector,\n",
    "    example_prompt=example_prompt,\n",
    "    suffix=\"Input: {input} -> Output:\",\n",
    "    prefix=\"Translate the following words from English to Italian:\",\n",
    "    input_variables=[\"input\"],\n",
    ")\n",
    "\n",
    "# print() is used so the newlines in the rendered prompt are shown as line breaks.\n",
    "print(prompt.format(input=\"word\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langchain0_1",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
